import numpy as np

from typing import List, Optional
from oracles.saddle import ArrayPair, BaseSmoothSaddleOracle, OracleLinearComb
from methods.saddle import Logger
from .base import BaseSaddleMethod
from .constraints import ConstraintsL2


class CentralizedExtragradient(BaseSaddleMethod):
    """Extragradient iteration over the equally weighted average of several oracles.

    The base class is initialised with an ``OracleLinearComb`` that gives every
    oracle the weight ``1 / len(oracles)``. ``step`` itself queries the
    individual oracles and averages their gradients with ``ArrayPair.mean``.

    Parameters
    ----------
    oracles:
        Per-node oracles. An empty list raises ``ZeroDivisionError`` while the
        weights are being computed.
    stepsize:
        Multiplier applied to the averaged gradient in both half-steps.
    z_0:
        Starting point; forwarded to the base class and stored as ``self.z``.
    logger:
        Forwarded unchanged to ``BaseSaddleMethod``.
    constraints:
        Stored as ``self.constraints``.
        NOTE(review): nothing in this class reads ``self.constraints``, so no
        projection happens in ``step`` - confirm whether that is intended.
    """

    def __init__(
        self,
        oracles: List[BaseSmoothSaddleOracle],
        stepsize: float,
        z_0: ArrayPair,
        logger: Optional[Logger],
        constraints: Optional[ConstraintsL2] = None
    ):
        node_count = len(oracles)
        self._num_nodes = node_count
        uniform_weights = [1 / node_count] * node_count
        averaged_oracle = OracleLinearComb(oracles, uniform_weights)
        # The two ``None`` positional arguments are passed through as-is; their
        # meaning is defined by BaseSaddleMethod (not visible from this class).
        super().__init__(averaged_oracle, z_0, None, None, logger)

        self.oracle_list = oracles
        self.stepsize = stepsize
        # Not touched again anywhere in this class; kept as a public attribute.
        self.grad_list = None
        self.constraints = constraints
        self.z = z_0

    def _mean_oracle_grad(self, point):
        """Return the ``ArrayPair.mean`` of every oracle's gradient at ``point``."""
        per_node = [node_oracle.grad(point) for node_oracle in self.oracle_list]
        return ArrayPair.mean(per_node)

    def step(self):
        """Perform one extragradient iteration and update the bookkeeping counters.

        First an extrapolated point is built from the averaged gradient at the
        current iterate; then the iterate is moved using the averaged gradient
        evaluated at that extrapolated point.
        """
        gamma = self.stepsize

        # Half-step: look ahead along the averaged gradient at the current z.
        lookahead = self.z - gamma * self._mean_oracle_grad(self.z)

        # Full step: move the current z using the gradient at the look-ahead point.
        self.z = self.z - gamma * self._mean_oracle_grad(lookahead)

        # Every oracle was queried twice (once at z, once at the look-ahead point).
        nodes = self._num_nodes
        self.gradient_calls += 2 * nodes
        # NOTE(review): 4 units per node are added here; presumably this counts
        # exchanged vectors per iteration - confirm against BaseSaddleMethod.
        self.current_round_volume += 4 * nodes
